"""
The code is released exclusively for review purposes with the following terms:
PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE 
CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE 
REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.
"""

""" Train a black box model from the training partition """
import numpy as np
import sys
sys.path.append("../utilities")
import os

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.linear_model import LinearRegression
import joblib
import yaml
from utils import fname_data, fname_model, create_dir_if_not_exist
import pickle

# Parse the command-line arguments that drive this run
import argparse
parser = argparse.ArgumentParser(
    description="Train a black box model from the training partition")
# Every argument is used further down the script, so each one is marked as
# required: a missing value now fails here with a usage message instead of
# surfacing later as an unrelated TypeError/NameError.
parser.add_argument(
    "--config_fname", required=True,
    help="name of the YAML config file inside the 'config' directory")
parser.add_argument(
    "--dataset_key", required=True,
    help="dataset identifier; selects the data/<dataset_key> directory")
parser.add_argument(
    "--model_key", required=True,
    help="black box model to train: RFR, RFC or LinReg")
args = parser.parse_args()

# Load the config file
config_path = os.path.join("config", args.config_fname)
# NOTE(review): yaml.FullLoader is kept to preserve the existing behaviour.
# If the config only holds plain scalars/lists/mappings, yaml.safe_load is
# the safer choice for files that may not be fully trusted -- confirm.
# The context manager closes the file handle as soon as parsing is done.
with open(config_path) as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Import the necessary data
datafname = os.path.join("data", args.dataset_key, "input",
                         fname_data(config, args.dataset_key) + ".pkl")
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point this script at data files produced by a trusted pipeline.
# The pickle holds the train/test split (features, targets, weights) followed
# by feature-name and feature-index metadata; only the split is needed for
# training, but the full tuple is unpacked to match the stored layout.
with open(datafname, "rb") as datafile:
    ((X_train, X_test, y_train, y_test, w_train, w_test),
     categorical_feature_names, numerical_feature_names,
     categorical_feature_inds, numerical_feature_inds,
     colnames_onehot, colnames_orig) = pickle.load(datafile)

# create the model
# Map each supported model key to its estimator class; the constructor
# keyword arguments come from the "Bb_Model" section of the config.
model_classes = {
    "RFR": RandomForestRegressor,
    "RFC": RandomForestClassifier,
    "LinReg": LinearRegression,
}
if args.model_key not in model_classes:
    # Fail with a clear message instead of an opaque NameError on bb_model.
    raise ValueError(
        "Unknown model_key {!r}; expected one of {}".format(
            args.model_key, sorted(model_classes)))
bb_model = model_classes[args.model_key](**config["Bb_Model"][args.model_key])

# Fit on the training partition and report the weighted score on the test
# partition. The weights are passed by keyword so the call does not depend
# on the positional order of the estimator's fit/score signature.
bb_model.fit(X_train, y_train, sample_weight=w_train)
print(bb_model.score(X_test, y_test, sample_weight=w_test))

# dump the model
modelfname = fname_model(config, args.model_key,
                         args.dataset_key) + ".pkl"
dirname = os.path.join("data", args.dataset_key, "models")
create_dir_if_not_exist(dirname)
# Use a context manager so the file is flushed and closed deterministically
# once the model has been written, rather than leaking the handle.
with open(os.path.join(dirname, modelfname), "wb") as modelfile:
    joblib.dump(bb_model, modelfile)



